{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[  9,  12,  15],\n",
      "        [ 18,  21,  24]])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "x = torch.tensor([[1, 2, 3], [4, 5, 6]])\n",
    "y = torch.tensor([[7, 8, 9], [10, 11, 12]])\n",
    "# element-wise arithmetic on two 2x3 integer tensors\n",
    "f = y + 2 * x\n",
    "print(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0.,  0.,  0.],\n",
      "        [ 0.,  0.,  0.]])\n",
      "tensor([[ 1.,  1.,  1.],\n",
      "        [ 1.,  1.,  1.]])\n",
      "tensor([[ 0.9663,  0.0252,  0.5935],\n",
      "        [ 0.0047,  0.9107,  0.7989]])\n"
     ]
    }
   ],
   "source": [
    "shape = [2, 3]\n",
    "xzeros = torch.zeros(shape)   # tensor filled with zeros\n",
    "xones = torch.ones(shape)     # tensor filled with ones\n",
    "xrnd = torch.rand(shape)      # tensor filled with random values\n",
    "for filled in (xzeros, xones, xrnd):\n",
    "    print(filled)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0.8823,  0.9150,  0.3829],\n",
      "        [ 0.9593,  0.3904,  0.6009]])\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(42)   # fix the seed so the random draw below is repeatable\n",
    "seeded_sample = torch.rand([2, 3])\n",
    "print(seeded_sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[  8,  10,  12],\n",
      "        [ 14,  16,  18]])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'torch.LongTensor'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "xnp= np.array([[1,2,3],[4,5,6]])\n",
    "# add the numpy array to the tensor y; the recorded output below reports the result as a torch.LongTensor\n",
    "f2= xnp + y\n",
    "print(f2)\n",
    "f2.type()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.LongTensor\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "numpy.ndarray"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(f.type()) #call the tensor's type() method\n",
    "fnp = f.numpy() #create an array from the tensor\n",
    "type(fnp) #uses Python's built-in type()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 1,  2,  3],\n",
      "        [ 4,  5,  6]])\n",
      "torch.LongTensor\n"
     ]
    }
   ],
   "source": [
    "xtensor = torch.from_numpy(xnp) #create a tensor from the array xnp\n",
    "print(xtensor)\n",
    "print(xtensor.type()) #the tensor's type string"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "tensor([ 1.,  0.,  1.], dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "a = np.ones(3)\n",
    "t = torch.from_numpy(a)   # tensor created from the array\n",
    "b = t.numpy()             # array created back from the tensor\n",
    "b[1] = 0                  # write through the second array ...\n",
    "same_value = a[1] == b[1]\n",
    "print(same_value)         # ... and the original array shows the same value\n",
    "print(t)                  # so does the tensor - they share the same memory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "can't convert np.ndarray of type numpy.int8. The only supported types are: double, float, float16, int64, int32, and uint8.",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-8-a7d699515564>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mint8np\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mones\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mint8\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mbad\u001b[0m\u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_numpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mint8np\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m: can't convert np.ndarray of type numpy.int8. The only supported types are: double, float, float16, int64, int32, and uint8."
     ]
    }
   ],
   "source": [
    "int8np = np.ones((2, 3), dtype=np.int8)\n",
    "# The recorded run rejected int8 arrays with a TypeError. Catch and show the\n",
    "# error so the demonstration does not abort a Restart & Run All.\n",
    "try:\n",
    "    bad = torch.from_numpy(int8np)\n",
    "except TypeError as err:\n",
    "    print('TypeError:', err)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'torch.IntTensor'"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "int32np = int8np.astype(np.int32)   # cast the int8 array to int32 first\n",
    "good = torch.from_numpy(int32np)\n",
    "good.type()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'torch.IntTensor'"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xint = torch.ones(2, 3, dtype=torch.int32)   # torch.int is an alias of torch.int32\n",
    "xint.type()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([ 1,  2,  3])\n",
      "tensor([ 4,  5])\n"
     ]
    }
   ],
   "source": [
    "print(x[0])       # first row\n",
    "print(x[1, :2])   # first two entries of the second row"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2, 3])\n",
      "tensor([ 1,  2,  3,  4,  5,  6])\n",
      "tensor([[ 1,  2],\n",
      "        [ 3,  4],\n",
      "        [ 5,  6]])\n",
      "tensor([[ 1],\n",
      "        [ 2],\n",
      "        [ 3],\n",
      "        [ 4],\n",
      "        [ 5],\n",
      "        [ 6]])\n"
     ]
    }
   ],
   "source": [
    "print(x.size())\n",
    "# the same six elements viewed with three different shapes\n",
    "for new_shape in [(-1,), (3, 2), (6, 1)]:\n",
    "    print(x.view(*new_shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 1,  2],\n",
      "        [ 3,  4],\n",
      "        [ 5,  6]])\n"
     ]
    }
   ],
   "source": [
    "print(x.view(3,-1)) #-1 lets view() work out that dimension from the number of elements"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([3, 2])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 3])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(x.transpose(0,1).size()) #size of the transposed result\n",
    "x.size() #x itself keeps its original size, as the output shows"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([4, 3, 2, 1])\n",
      "torch.Size([4, 3, 2, 1])\n"
     ]
    }
   ],
   "source": [
    "a = torch.ones(1, 2, 3, 4)\n",
    "two_step = a.transpose(0, 3).transpose(1, 2)   # swaps the axes in two steps\n",
    "one_step = a.permute(3, 2, 1, 0)               # reorders all the axes at once\n",
    "print(two_step.size())\n",
    "print(one_step.size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 1,  2,  3],\n",
      "        [ 4,  5,  6]])\n",
      "tensor([[ 1,  4],\n",
      "        [ 2,  5],\n",
      "        [ 3,  6]])\n"
     ]
    }
   ],
   "source": [
    "# Demonstrate the in-place transpose on a copy: mutating x itself would make\n",
    "# this cell non-idempotent and change what the earlier cells print on re-run.\n",
    "xt = x.clone()\n",
    "print(xt)\n",
    "xt.transpose_(1, 0)   # trailing underscore: the tensor is modified in place\n",
    "print(xt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Dataset CIFAR10\n",
       "    Number of datapoints: 50000\n",
       "    Split: train\n",
       "    Root Location: ./data\n",
       "    Transforms (if any): ToTensor()\n",
       "    Target Transforms (if any): None"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "trainset = torchvision.datasets.CIFAR10(root='./data',   #data root directory\n",
    "                                        train=True,      #The training set\n",
    "                                        download=True,   #checks if data has downloaded else does so\n",
    "                                        transform=transforms.ToTensor())  #transforms to tensor\n",
    "trainset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "size of image torch.Size([3, 32, 32]) label 6\n",
      "size of image torch.Size([3, 32, 32]) label 9\n",
      "size of image torch.Size([3, 32, 32]) label 9\n"
     ]
    }
   ],
   "source": [
    "# Show the first three samples. Each sample is fetched once and unpacked,\n",
    "# instead of indexing the dataset twice per iteration.\n",
    "for i in range(min(3, len(trainset))):\n",
    "    image, label = trainset[i]\n",
    "    print('size of image {} label {}'.format(image.size(), label))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7f7ed98e9588>"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAAH3FJREFUeJztnVuMXNd1pv9Vt67qezf7QrJJiRJ1\nGcmxRMmMIEiZjB3PBIoRRDaQZOwHQw9GGAQxEAPJg+AAYw8wD/ZgbMMPAw/okRJl4PFlfImFQJjE\nEWwIiQNFlCXrHomiKLHJVrPJ7mZ3dVXXdc1DlyZUa/+bJTZZTWn/H0B0ca/a56zaddY5VeevtZa5\nO4QQ6ZHZbgeEENuDgl+IRFHwC5EoCn4hEkXBL0SiKPiFSBQFvxCJouAXIlEU/EIkSm4rk83sHgBf\nB5AF8D/d/Uux5+fzee8rFoO2VqtF52UQ/hVi1vi+Cjl+XstHbLlsltrMwjs0i5xDIz42m/w1x353\nmY35SH6x2fY231eb780ykRcQod0Ov7aY79HtRfy3yCIzWybiRzbD3092DABAO/JrWY8dCGxOdHth\nFpdXUa6sd7Wziw5+M8sC+O8A/gOAWQBPmNnD7v4Cm9NXLOLA7R8K2paXF+m++jLhN368wBfnqh39\n1DY5PkBtE6OD1FbI5oPjub4SnYMsX+LFpWVqqzf5axsbHaG2TKsRHK/VanTO+vo6tRVL4ZM1ALTA\nT16Vajk4PjI6TOfA+fbqtTq1ZRF+XwB+shka5O/zwAA/PvJ5vh7ViI8eu0BkwsdI7DU3PRzfX37g\nB3w/m3fb9TPfyR0Ajrr7MXevA/gOgHu3sD0hRA/ZSvDPADhx3v9nO2NCiPcAW/nOH/rc8Y7PqmZ2\nCMAhAOjr69vC7oQQl5KtXPlnAew97/97AJza/CR3P+zuB939YC7Pv5sJIXrLVoL/CQDXm9k1ZlYA\n8EkAD18at4QQl5uL/tjv7k0z+yyAv8WG1Peguz8fm7O+vo7nXwg/ZfnMGTpvnNxgtR38zutEa4ja\nrDRFbWttrjqUW+E78G4FOqeyzu/YVqr8DnyjxaWtMxGNs5gL+9hs8u1lyd1mIP5VrbK+Rm3Ndvh1\n2/oOOicTUQEbEbWilOPHQZncMV9sNemc/n5+t98y/NOrETUIABCRDyvrYYWm2QiPA0A2F35fGutV\n7sMmtqTzu/sjAB7ZyjaEENuDfuEnRKIo+IVIFAW/EImi4BciURT8QiTKlu72v1syAEo5IlNFfvx3\nNZH09k3zBJepyXFqK8WknEjWVrUWToBZb3AZyiPbK5QiCUGRxB5v8/2NjIcTmpoNvr1CnvsRSbZE\ntsDftFo9vFaNJl+P/sj2cgPcx2JkXtPCcmQmkiXYjGTgxTJJBwd4Mll5rUJtjWZY0oslVK6unAuO\nt2Nv2Obtd/1MIcT7CgW/EImi4BciURT8QiSKgl+IROnp3X4zR9HCCRVDQ9yVG2bGguM7SjwTJN/m\npanKizzZptXm58NqJex7huf1YDhSFiwXuUu9fG6Vz4u8a+ND4TvOqys8CaceSdCpkqQTIF6XbpCU\nwmrUeeJJpsVfWD6SYNQipcsAIEduz9dqfE4hz9/QTJsnBNXKS9QGkhQGAH3kMG62uSJxbi2s+LQi\n9Rg3oyu/EImi4BciURT8QiSKgl+IRFHwC5EoCn4hEqWnUl/ODGN94V2WIlLOCEnqmBzmNdNapF0U\ngEifGSCbixSSI3XYau2I1BTR5XKR5JJWjUtinuXn7NOnw12AWg3+qlcrPOmk0uKy6GAp0n2nRtp1\ngb/mjHGZKtsX6ZSzxmXd/nzYx1ykFdZ6pO5itcGlvnakydpymfu4XAkfP2UiLQPAeiN8DNQjtRo3\noyu/EImi4BciURT8QiSKgl+IRFHwC5
EoCn4hEmVLUp+ZHQewig31rOnuB6M7yxomR8OSzVCeS2zF\nYtiWyXJppRSpj9doctmrHclUcw9LQPVIvb1WncuAbY9kzEUkNs/xrLPVejhDr9Xi61uJtAZrRmyr\na9z/k4thP/IZvr3hMl/7xpu8nVv1HJcqr5q4Ljg+NbWHzrGhcH08AKgtnaW2cplnR55b5VLfmXNh\nWff4Ce5HKxsO3Vqdy4ObuRQ6/0fcnb8zQogrEn3sFyJRthr8DuDvzOxJMzt0KRwSQvSGrX7sv9vd\nT5nZFICfmNlL7v7Y+U/onBQOAUAx8r1eCNFbtnTld/dTnb+nAfwIwB2B5xx294PufrCQ07cMIa4U\nLjoazWzAzIbeegzgNwE8d6kcE0JcXrbysX8awI867a1yAP63u//f2IR8Lovdk+HCjsMFLlEM9oel\nLYtIZYhkWFkkm65W5bJRhsiAO4Z427CBAZ6NtnKOiyQjwzxjbjVSVPP1k+Ftlmv8K1chkgg20x/J\nSszzzMPjZ8PZhTWPFF2NZPWNDA9R2103c4V5ZS4s63olsq8Jni1aq/D1KJf5tbQvz7e5d2f4tU1N\nTdM58yth6fDsy2/SOZu56OB392MAbr3Y+UKI7UVfwoVIFAW/EImi4BciURT8QiSKgl+IROltAc+s\nYXwonG2Xq4elIQDoy4fd7O8L96UDgFqVy2GNSL+10dFwX0AAcFL0sd7i59BGI1JccpD38Tu1EO7F\nBgCvvs6zvRZWw68tUgsSV0d6Hn783x6gtj27uP/ff/JYcPyfjnIpqtnmmYy5DJfmVpcXqK1SDq/j\n0BCX3tDi2YXFIp9XINmnANBvfF6zFX5zrtq7m84ZWgz3cnzmNb4Wm9GVX4hEUfALkSgKfiESRcEv\nRKIo+IVIlN7e7c/lMDW+I2irLvK74hkLu1kmbY4AoBqpZZazSD27SFsrdqasNvhd6tExnqBTb/E7\n2MdmT1Hb4gr3kdX3y0ZafA0X+famcuG7ygBQXOSKxPXDO4Pjc+Pcj/nl09RWq/A1furll6ktQ9pX\nNQYircZGeEINMjxkRka4+jTUjrQHI3Uevb5C5+wjCXJ9+e6v57ryC5EoCn4hEkXBL0SiKPiFSBQF\nvxCJouAXIlF6LPXlMTYxGbSNDfL2WplMOClieWWJzmmslfn2WrF2XbygnZMEo8FBXqevAW578RiX\nqNZqvPVTsdjHbYWwj6UBLkONZbks+uTReWpr1vnhUxsJS32TY3w9DFx+azS5FFyp81qCa6RWX73J\nX7NFpNtINzfkM5FWb5lI7cJceB2bNS6lOpGJSe5ZEF35hUgUBb8QiaLgFyJRFPxCJIqCX4hEUfAL\nkSgXlPrM7EEAvw3gtLv/SmdsHMB3AewDcBzA77s7193+dWsAke0s0s6I0Repp9aPcNYTAOQi57xM\nJlKPj8iAfSXeruvMmzwrrnKGL9m141wSq3HVC0Ui6d24f4bOyUQ22MzyNV6JSK25bLjO4FCBvy87\nxvZT2/7rr6K21954gtpeevlkcLyQi8hozmXiZpOHTIZkVAJAvsDXsd0OH1ftiK5oFj5OI0rkO+jm\nyv+XAO7ZNHY/gEfd/XoAj3b+L4R4D3HB4Hf3xwAsbhq+F8BDnccPAfj4JfZLCHGZudjv/NPuPgcA\nnb9Tl84lIUQvuOw3/MzskJkdMbMjq5XIl1UhRE+52OCfN7NdAND5S+svufthdz/o7geH+vlNLCFE\nb7nY4H8YwH2dx/cB+PGlcUcI0Su6kfq+DeDDACbMbBbAFwB8CcD3zOwzAN4A8Hvd7Kztjup6uFih\nNXhmFhDOwFpb4wUO6w1+Xmtm+CeQcoVLcyvENrOXL6M3+faunuDCzP7dXBqqrPN5MzfcGhwvOP/K\ntXSOF0ItjYYLrgIAzvJMtb07dwXHl9d4tuK1/+Z6ahse41mJw2M3UdvSQnj9l87xlmf5iByZcZ5R\n2W
hHskV5sihajfDxHUkSpK3j3kVS34WD390/RUwffRf7EUJcYegXfkIkioJfiERR8AuRKAp+IRJF\nwS9EovS0gKfD0bKwHOItXlCRyRqlIi/6OTjEpaFTC1xWfG12gdpy+bAfhXneV299nm/v+iku5330\nw1z2evXk5lSLf2VoJlwgdWJHuKAmAJxe4EU6R0cjsleb+18gBStPL4Sz7AAgV1ymtoXlOWo7Ocez\n8PL58HEwOsy1t2qVC2ae49dLi2hz7YgMmLHwPItkmEbaPHaNrvxCJIqCX4hEUfALkSgKfiESRcEv\nRKIo+IVIlJ5KfdlsBqOjg0FbM8elvnI5nJHmDS6fnFvlWVuvv8GlrXKZy0alYvhcOfcazy6cLvKi\njjMzV1Pb6O5rqC2/GkkRI0VN99x6B5/yJpffSk0uVbbAMwXX1sK2Xf1hKRIA6i3+umwgfNwAwJ6B\n3dQ2NBqWOFfPvknnnJ4/S20N4/Lmep0XBUWGa3MDfeEs03o1ImGSgqBGZMOgS10/UwjxvkLBL0Si\nKPiFSBQFvxCJouAXIlF6ere/3WpidTl8JzVX57Xu8qQ1EXgJOeSy3FgpcyVgbIgnsowOhO/KVpf4\n3f6p3bwG3swt/47anputU9vLR7ntrl3jwfHlZT5nen+47h8AZFChtnqNKwGjHr5zv3Ka30kv1Xkt\nwV3j4dcFAMstXlcvf8tYcLwaSRT6x0ceprbZE/w1ZyMtuWKNtFgeUSPWVq4RXiuWBBfcRtfPFEK8\nr1DwC5EoCn4hEkXBL0SiKPiFSBQFvxCJ0k27rgcB/DaA0+7+K52xLwL4AwBv6R6fd/dHutlhlige\nrUgSgxOZJEPaeAFAy7jUt8QVJaysROq31cJy2a4RLg/+6kc+Qm17bryT2n74Fw9S285Ikku2Hq5P\nePLYq3x7195MbcUd11HbgHN5trIY7t1aaoelNwCoV7mseGaV20YneRLUjp37guPV8jCdk+EmtAo8\nmSlWw6/R4FKrNcMJauY8ca3ZDIfupZb6/hLAPYHxr7n7gc6/rgJfCHHlcMHgd/fHAPBysUKI9yRb\n+c7/WTN7xsweNDP+WU4IcUVyscH/DQD7ARwAMAfgK+yJZnbIzI6Y2ZFyhX/vEUL0losKfnefd/eW\nu7cBfBMALRPj7ofd/aC7Hxzs51VthBC95aKC38x2nfffTwB47tK4I4ToFd1Ifd8G8GEAE2Y2C+AL\nAD5sZgcAOIDjAP6wm50ZACNKRItkKQG8bVGkcxK8GtlepATe+A7e5mtnf1havP3gDXTOTXdxOW/p\nNJc3+5o88/DaPXuorU1e3M4pXjuvuc4l00okG7De5PMa1fCh1QKXKV89OUttzz53hNruupP7uGNn\nOKtyZTUsRQIA6fAFAJjYx2Xddqy9Vj0i2xEJ+dwCb19WWw072SbZlCEuGPzu/qnA8ANd70EIcUWi\nX/gJkSgKfiESRcEvRKIo+IVIFAW/EInS0wKe7kCbZDBVa1yiKJAstlyOF0zMZrj8c91O/mvkYomf\nD/ddvTc4fuuv8cy9XTfeQm1P/9NfUNtVe7mPOz/wQWorTO4Pjuf6R+icyjqXHKsrPHNv/tQJalua\nD8t2rQbPzisNhQukAsDEBH+vT5x6itqmd80Ex5uVSBZplbfdsrUlamt5OKMSAJxp3ABKfeHXVtjJ\nX/NKH8l0fRcRrSu/EImi4BciURT8QiSKgl+IRFHwC5EoCn4hEqWnUp+ZIZ8N73IpUqCxtR6WNUr9\nJTonm+HSylQkc+/EHM+k2n97qJQhsOeD4fENuGTXWF2jtpEhLs1N3nCA2tZy4Z52zz/1BJ1Tq3I/\nVlb4epw5+Qa1ZVthqbVY5IfczDVhWQ4AbrmBFxJtZnmmXT47Gh4v8KzP3Dov0ll5/SS1MRkbAJqR\ny2yZ9JXs38Ff1zTpAZnPd38915VfiERR8AuRKAp+IRJFwS9Eoij4
hUiU3ib2tNuoVcN3Uvv7uCtW\nDN8NzWd4DTlvcVtpkLfy+p3/+DvUdtdvfTQ4PjwxTefMH3uR2rIR/5dXeQ2/heP/Qm2nVsN3nH/2\n139N5wyWeALJeo0nwOyc5orE8FD4TvVrszwZqB5Zj/Hd+6jthg9+iNrQ6gsOLy7zeoEVoi4BwFKV\n+2jOj+H1Kk9cK5MWW17mqsNNYRED7e67denKL0SqKPiFSBQFvxCJouAXIlEU/EIkioJfiETppl3X\nXgB/BWAngDaAw+7+dTMbB/BdAPuw0bLr992dFzgD4HC0ndTWa/OkCGuGZZKmR1pyRWqmFfuGqe3A\nh7hs1JcPS2IvPM1ryC2depXaajUu5awuLVLbiaMvUFvZw8lO+Rbf12COS5/DRZ5cMjnGpb65+TeD\n481IW7bKKpcVT7zGk4iA56mlXA7XICzm+PHR7JuitrNNfuyUSrwGYf8QT0Ir5cJy5Gplhc5ptsOS\n47tQ+rq68jcB/Km73wTgTgB/bGY3A7gfwKPufj2ARzv/F0K8R7hg8Lv7nLv/ovN4FcCLAGYA3Avg\noc7THgLw8cvlpBDi0vOuvvOb2T4AtwF4HMC0u88BGycIAPyzkhDiiqPr4DezQQA/APA5d+dfRt45\n75CZHTGzI2tVXktfCNFbugp+M8tjI/C/5e4/7AzPm9mujn0XgGDDc3c/7O4H3f3gQKlwKXwWQlwC\nLhj8ZmYAHgDwort/9TzTwwDu6zy+D8CPL717QojLRTdZfXcD+DSAZ83s6c7Y5wF8CcD3zOwzAN4A\n8HsX3pRjQy18J+0m/0qQy4dr7rUiNdPq4NlX0yO8rt7fPvw31DY+HZaUpnaF23gBQL3Cs/Py+bDE\nAwCDA1xSymW4NDdA5MidU+GabwBQXeUKbSnLfTy7cIbaGvXwezNU5JJXvcylvleeOkJtcy+9TG21\nJmmhledr2Iqt7x4ufWKAH8OZPi61FolsNwa+Vjd94JrgeKl4jM7ZzAWD393/AQDLcQznuAohrnj0\nCz8hEkXBL0SiKPiFSBQFvxCJouAXIlF6WsATbmi3w8JBIZJZVsyR4ocZXmjRIy2c2nWeWXbmTDgb\nDQDKC2FbqcF/8NgGf13jY1x+G909SW3NVo3aTp4K++iRfK9Mhh8G9SaXTLPGC38OFMPyLEnQ3Nhe\nzBjJ0mzVuZyaIcfbSoXLm/U+Ig8CGNrN136txFubrba5DLi+Fr4G7xi+ls6ZINJtLt99SOvKL0Si\nKPiFSBQFvxCJouAXIlEU/EIkioJfiETprdQHQ8bCWWLFPp7B5CRDb6AUlpMAYGBogtoqDZ5htWOI\n1xzIET/q5+bpnHaGb6+S59LW9HQ4awsA2nUuG914y57g+M9/+iidU/cKteWNy6nVMp83PBTOSizk\n+CGXtUg/u3X+nr02x2W75eXwe1azNTpn8gZ+TZwZjWQlOn+vl87wtSqshyXTgZlIJmYlnDXZjqil\nm9GVX4hEUfALkSgKfiESRcEvRKIo+IVIlJ7e7c8YUMiFzzeVGk+YyJKWUe1IfblKgydnZPM8SaSv\nwO/m5vNhPwr9vG3VyDBPMHpzgasElZnwXXsAmNp7HbWdPB2uq/eBX72bzikvnKK2Yy/zVlhrZZ7I\nksuG139khNcmNFLfEQDmTnIf33g9ktjTF17/4WmuFE2OR3yMqA62yN/rsSUeajNT48HxPaP8GDj6\nQjiBq1blSWub0ZVfiERR8AuRKAp+IRJFwS9Eoij4hUgUBb8QiXJBqc/M9gL4KwA7sdFr67C7f93M\nvgjgDwAsdJ76eXd/JLqznGF6Mny+aZw9S+dVW2EJaI3nZsAzvJVXLpJcMjzMkykKpBVWdY3X8CvF\naqrVue3Iz39ObdfeyCXC2dmwBJSJ1Dvs7+O1+LIRObVU4tLWWjks9VWrXIJtRlq2DZa4H3fddgO1\nFUmCUTPLaxO2GjwJp3qCS32Z
1SK1TfUPUdttN3wgPGd0ms55cu614HizwV/XZrrR+ZsA/tTdf2Fm\nQwCeNLOfdGxfc/f/1vXehBBXDN306psDMNd5vGpmLwKYudyOCSEuL+/qO7+Z7QNwG4DHO0OfNbNn\nzOxBM+Otb4UQVxxdB7+ZDQL4AYDPufsKgG8A2A/gADY+GXyFzDtkZkfM7MhKhX+nE0L0lq6C38zy\n2Aj8b7n7DwHA3efdveXubQDfBHBHaK67H3b3g+5+cLifVzoRQvSWCwa/mRmABwC86O5fPW9813lP\n+wSA5y69e0KIy0U3d/vvBvBpAM+a2dOdsc8D+JSZHQDgAI4D+MMLbahQMFy1N3z1HzEukxw9EZZe\n5hd4dl69xaWhwUH+stcqPEOs1S4Hx7ORc+jiApcwV8tclllvcD+yzm1Dg+FbL/NvLtI5s2tcvmo7\nlwinJ7ksau1wdtnSMq+31zfA37PRES6VFbJ8/Wt1IvnmuLy5VuPbq5cjLcrafN51e3dS2+6d4XU8\nMcsl3bML4ZhoxlqebaKbu/3/ACB0BEQ1fSHElY1+4SdEoij4hUgUBb8QiaLgFyJRFPxCJEpPC3hm\nc4bhMZIZR6QLABibyoYNA7wI45l5XhB0PdLuKlfgxRvZtHaDZxA2WtyPc1Uuew1EstjWK1yaq66H\nC3jWIz62IjZ3svYAyiuRdl3D4UKow8O82Gm1yrd35ixfq8FBnl1omfD1zZpcJi7keBHXPq5Io1Dg\na7Xvun3UVq2EfXnssRfonGdePh3e1nr3WX268guRKAp+IRJFwS9Eoij4hUgUBb8QiaLgFyJReir1\nmRlyxfAui8M81398MHyOylW5jJYv8eymlUjfNLT4+bBUnApPyfN9tWq8n12hn/uRz/H1yGa5xFnz\nsC/1Bpc3PZK5Z1wRg9e55Ngipnwkmw4FLm8uL3Gpr1rn/elGRsPSbY5IgACQiax9BVxKmz+zSm1L\nkQzO1bVwlubf/+wlvi+iiq7XJfUJIS6Agl+IRFHwC5EoCn4hEkXBL0SiKPiFSJSeSn3ttqHMCiBm\nB+m8wYGwbpQvcR1qIJJ+NTLCpbnyCu8lV14JF1QsVyJZfevcNlTgBTCLpC8gADRrXOLM5cLn80Lk\nNJ/v49loZnxif6QQaoaYmi0uRRVKkR6Ko1zeXFzkEtsqkT6Hx/naVyI9A185zguyvvTsCWqbHufZ\notN7yGvL8ON0ghQ0nV/lsuc7Nt/1M4UQ7ysU/EIkioJfiERR8AuRKAp+IRLlgnf7zawI4DEAfZ3n\nf9/dv2Bm1wD4DoBxAL8A8Gl3j7bhrdeB2dfDttoyvzs/NBm+Q1wsRRI6uHiA8XH+sstrvI7c8nLY\ntnSWJ4Is8ZvDyLb5Xfa2cyWj1eIKAtphW+wsbxme2JPN8bWqRpKgnNzUz5M2XgDQrPCWYq1Ifb9W\nJFlouRyex7p4AcBiRPE5fpS/octn16itvsZ3uHMk3Mrrpqtn6Bzm4itvrtA5m+nmyl8D8Bvufis2\n2nHfY2Z3AvgygK+5+/UAlgB8puu9CiG2nQsGv2/wVofKfOefA/gNAN/vjD8E4OOXxUMhxGWhq+/8\nZpbtdOg9DeAnAF4FsOz+/z/czQLgn1GEEFccXQW/u7fc/QCAPQDuAHBT6GmhuWZ2yMyOmNmRc2Ve\n/EEI0Vve1d1+d18G8DMAdwIYNbO37gbtAXCKzDns7gfd/eDIYKTjgRCip1ww+M1s0sxGO49LAP49\ngBcB/BTA73aedh+AH18uJ4UQl55uEnt2AXjIzLLYOFl8z93/xsxeAPAdM/svAJ4C8MCFNuSWQys/\nEbQ1CgfpvFo7nMiSaYZbUwFAcYTLV6OT/BPIWIYnnoxXwokWy4u8vdPyGS7nVdf48reaXD6E83N2\nuxn2cb3Kv3IVCpF6gTnu/+o6Tzypkq94+YgaPJQJJ6sAQDvDJaxGg69j30BYMi3meb3A0QL38V
qM\nUtsHb+Vtw2685VZq23fddcHxO+7k8ubsqXJw/B9f5TGxmQsGv7s/A+C2wPgxbHz/F0K8B9Ev/IRI\nFAW/EImi4BciURT8QiSKgl+IRDGPZI9d8p2ZLQB4K69vAkD3usTlQ368Hfnxdt5rflzt7pPdbLCn\nwf+2HZsdcXcu7ssP+SE/Lqsf+tgvRKIo+IVIlO0M/sPbuO/zkR9vR368nfetH9v2nV8Isb3oY78Q\nibItwW9m95jZv5jZUTO7fzt86Phx3MyeNbOnzexID/f7oJmdNrPnzhsbN7OfmNkrnb9j2+THF83s\nZGdNnjazj/XAj71m9lMze9HMnjezP+mM93RNIn70dE3MrGhm/2xmv+z48Z8749eY2eOd9fiumUVS\nP7vA3Xv6D0AWG2XArgVQAPBLADf32o+OL8cBTGzDfn8dwO0Anjtv7L8CuL/z+H4AX94mP74I4M96\nvB67ANzeeTwE4GUAN/d6TSJ+9HRNABiAwc7jPIDHsVFA53sAPtkZ/x8A/mgr+9mOK/8dAI66+zHf\nKPX9HQD3boMf24a7PwZgc53qe7FRCBXoUUFU4kfPcfc5d/9F5/EqNorFzKDHaxLxo6f4Bpe9aO52\nBP8MgPPbmW5n8U8H8Hdm9qSZHdomH95i2t3ngI2DEMDUNvryWTN7pvO14LJ//TgfM9uHjfoRj2Mb\n12STH0CP16QXRXO3I/hDJXa2S3K4291vB/BbAP7YzH59m/y4kvgGgP3Y6NEwB+ArvdqxmQ0C+AGA\nz7l7990nLr8fPV8T30LR3G7ZjuCfBbD3vP/T4p+XG3c/1fl7GsCPsL2ViebNbBcAdP6e3g4n3H2+\nc+C1AXwTPVoTM8tjI+C+5e4/7Az3fE1CfmzXmnT2/a6L5nbLdgT/EwCu79y5LAD4JICHe+2EmQ2Y\n2dBbjwH8JoDn4rMuKw9joxAqsI0FUd8Ktg6fQA/WxMwMGzUgX3T3r55n6umaMD96vSY9K5rbqzuY\nm+5mfgwbd1JfBfDn2+TDtdhQGn4J4Ple+gHg29j4+NjAxiehzwDYAeBRAK90/o5vkx//C8CzAJ7B\nRvDt6oEfv4aNj7DPAHi68+9jvV6TiB89XRMAt2CjKO4z2DjR/Kfzjtl/BnAUwP8B0LeV/egXfkIk\nin7hJ0SiKPiFSBQFvxCJouAXIlEU/EIkioJfiERR8AuRKAp+IRLl/wHCOW2RBgdIrQAAAABJRU5E\nrkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f7ee16b1908>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "torchimage=trainset[0][0]    #first element of the first sample, ie the 1st image\n",
    "npimage=torchimage.permute(1, 2, 0) #changes the axes C H W to H W C; still a torch tensor despite the name\n",
    "plt.imshow(npimage)    #plots the image - no need to convert to numpy "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torchvision import transforms\n",
    "from PIL import Image\n",
    "import torch\n",
    "import csv\n",
    "import os\n",
    "\n",
    "class toyDataset(Dataset):\n",
    "    \"\"\"Image dataset described by a CSV file of (image name, label) rows.\n",
    "\n",
    "    The first 90% of the rows (in file order) form the training split and\n",
    "    the remaining rows the test split; ``train`` selects which split is\n",
    "    exposed through ``len()`` and indexing.\n",
    "\n",
    "    Args:\n",
    "        dataPath: directory holding the labels file and the image files.\n",
    "        labelsFile: name of the CSV file inside ``dataPath``; each row is\n",
    "            ``image name, label``.\n",
    "        transform: optional callable applied to every loaded image.\n",
    "        train: expose the training split when True, else the test split.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if an image named in the labels file does not exist.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dataPath, labelsFile, transform=None, train=True):\n",
    "        self.dataPath = dataPath    # the path to the data directory\n",
    "        self.transform = transform  # a transform object\n",
    "        self.train = train\n",
    "\n",
    "        # Build a list of (name, label) tuples. Labels are kept exactly as\n",
    "        # written in the file (including any surrounding spaces). Blank rows\n",
    "        # are skipped so that every entry has a file name.\n",
    "        with open(os.path.join(self.dataPath, labelsFile), newline='') as f:\n",
    "            self.labels = [tuple(row) for row in csv.reader(f) if row]\n",
    "\n",
    "        # Check that all image files exist, building the path the same way\n",
    "        # __getitem__ does.\n",
    "        for row in self.labels:\n",
    "            imagePath = os.path.join(self.dataPath, row[0])\n",
    "            assert os.path.isfile(imagePath), 'missing image file: ' + imagePath\n",
    "\n",
    "        # 90/10 split in file order (no shuffling).\n",
    "        ind = int(0.9 * len(self.labels))\n",
    "        self.trainset, self.testset = self.labels[:ind], self.labels[ind:]\n",
    "\n",
    "    def _active_split(self):\n",
    "        \"\"\"Return the rows of the split selected by ``self.train``.\"\"\"\n",
    "        return self.trainset if self.train else self.testset\n",
    "\n",
    "    # so we can use len(dataset)\n",
    "    def __len__(self):\n",
    "        return len(self._active_split())\n",
    "\n",
    "    # so we can use indexing\n",
    "    def __getitem__(self, idx):\n",
    "        imageName, imageLabel = self._active_split()[idx]\n",
    "        imagePath = os.path.join(self.dataPath, imageName)\n",
    "\n",
    "        # Read the pixel data while the file is open, then let the with-block\n",
    "        # close the handle instead of leaking one per sample.\n",
    "        with open(imagePath, 'rb') as imageFile:\n",
    "            image = Image.open(imageFile)\n",
    "            image.load()\n",
    "\n",
    "        # transforms the image if required\n",
    "        if self.transform:\n",
    "            image = self.transform(image)\n",
    "\n",
    "        return (image, imageLabel)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([3, 551, 816])\n",
      " toy\n",
      "108\n"
     ]
    }
   ],
   "source": [
    "toydata = toyDataset('data/GiuseppeToys', 'labels.csv', transform=transforms.ToTensor())\n",
    "first_image, first_label = toydata[0]   # fetch the first sample once instead of once per print\n",
    "print(first_image.size())   #the size of the first image in the dataset\n",
    "print(first_label)          #the label\n",
    "print(len(toydata))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# transform pipeline, applied in the order listed\n",
    "tforms = transforms.Compose([\n",
    "    transforms.Grayscale(3),\n",
    "    transforms.CenterCrop(300),\n",
    "    transforms.ToTensor(),\n",
    "])\n",
    "toyData = toyDataset('data/GiuseppeToys', 'labels.csv', transform=tforms)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(' toy', ' notoy ', ' notoy ', ' toy')\n",
      "torch.Size([4, 3, 300, 300])\n"
     ]
    }
   ],
   "source": [
    "toyloader = DataLoader(toyData, batch_size=4, shuffle=False)\n",
    "toyiter = iter(toyloader)\n",
    "# use the built-in next(); an iterator's .next() method is not part of the Python 3 iterator protocol\n",
    "images, labels = next(toyiter)\n",
    "print(labels)\n",
    "print(images.size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([ 3,  2,  2,  2])\n"
     ]
    }
   ],
   "source": [
    "from torchvision import datasets\n",
    "dataFromFolders = datasets.ImageFolder(root='data/GiuseppeToys/images', transform=tforms)\n",
    "folderloader = DataLoader(dataFromFolders, batch_size=4, shuffle=True)\n",
    "# built-in next() instead of the iterator's .next() method; shuffle=True means the batch differs between runs\n",
    "images, labels = next(iter(folderloader))\n",
    "print(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class RemoveChannel(object):\n",
    "    \"\"\"Transform that zeroes one colour channel of an image tensor.\n",
    "\n",
    "    The image is indexed as ``image[channel, :, :]``, i.e. channels first.\n",
    "\n",
    "    Args:\n",
    "        color: 'r', 'g' or 'b' - determines which colour to remove. Any\n",
    "            other value leaves the image unchanged.\n",
    "    \"\"\"\n",
    "\n",
    "    # Channel index of each colour; assumes the channels are in R, G, B order.\n",
    "    CHANNEL_INDEX = {'r': 0, 'g': 1, 'b': 2}\n",
    "\n",
    "    def __init__(self, color):\n",
    "        self.color = color\n",
    "\n",
    "    def __call__(self, image):\n",
    "        \"\"\"Zero the selected channel in place and return the same tensor.\"\"\"\n",
    "        channel = self.CHANNEL_INDEX.get(self.color)\n",
    "        if channel is not None:\n",
    "            image[channel, :, :] = 0\n",
    "        return image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "324"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cc2, cc3 = RemoveChannel('b'), RemoveChannel('g')\n",
    "tforms2 = transforms.Compose([transforms.CenterCrop(500), transforms.ToTensor(), cc2])\n",
    "# the third pipeline uses cc3 so the two augmented copies differ (cc2 was applied to both before)\n",
    "tforms3 = transforms.Compose([transforms.CenterCrop(500), transforms.ToTensor(), cc3])\n",
    "toydata2 = toyDataset('data/GiuseppeToys', 'labels.csv', transform=tforms2, train=True)\n",
    "toydata3 = toyDataset('data/GiuseppeToys', 'labels.csv', transform=tforms3, train=True)\n",
    "concatDataset = torch.utils.data.ConcatDataset([toydata, toydata2, toydata3])\n",
    "len(concatDataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
